/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "rtp_llm/cpp/cuda/cuda_type_utils.cuh"
#include "rtp_llm/cpp/cuda/launch_utils.h"
#include "rtp_llm/cpp/cuda/cuda_host_utils.h"
#include "no_aux_tc_kernels.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

namespace rtp_llm {
constexpr unsigned FULL_WARP_MASK      = 0xffffffff;            // participation mask covering all 32 lanes
constexpr int32_t  WARP_SIZE           = 32;                    // CUDA warp width
constexpr int32_t  BLOCK_SIZE          = 512;                   // threads per block for both kernels below
constexpr int32_t  NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;  // each warp processes one case/token

namespace warp_topk {

template<int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
    // Round `len` up to the nearest multiple of the compile-time constant
    // `size`; zero maps to zero (also avoids `len - 1` underflow for
    // unsigned T).
    return (len == 0) ? T(0) : (((len - 1) / size + 1) * size);
}

template<typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
    // A power of two has exactly one bit set, so clearing the lowest set bit
    // (v & (v - 1)) must leave zero; zero itself is rejected.
    return v != 0 && (v & (v - 1)) == 0;
}

template<bool greater, typename T>
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
    // "Better" means strictly larger when selecting maxima (greater == true)
    // and strictly smaller when selecting minima (greater == false).
    if constexpr (greater) {
        return val > baseline;
    } else {
        return val < baseline;
    }
}

template<bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, idxT baseline_index) {
    // Stable variant: equal values are ordered by index, preferring the
    // smaller index regardless of search direction.
    if (val == baseline) {
        return index < baseline_index;
    }
    if constexpr (greater) {
        return val > baseline;
    } else {
        return val < baseline;
    }
}

template<typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
    // Dynamic shared-memory budget for a block-wide top-k: the larger of
    //  (a) the per-warp top-k result cache (one value + one index per slot),
    //  (b) the staging buffer, with its value region padded to a 256-byte
    //      multiple so the index region that follows stays aligned.
    int64_t const topk_cache_bytes = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
    int64_t const staging_elems    = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
    return max(topk_cache_bytes, round_up_to_multiple_of<256>(staging_elems * sizeof(T)) + staging_elems * sizeof(idxT));
}

// Bitonic-merge step over a `size`-element sequence distributed across one
// warp: each lane holds size/WARP_SIZE (value, index) pairs in local arrays.
// When `is_stable`, value ties are broken by index. Recurses down to the
// 32-element warp-shuffle specialization below.
template<int size, bool ascending, bool reverse, typename T, typename idxT, bool is_stable>
struct BitonicMerge {
    // input should be a bitonic sequence, and sort it to be a monotonic sequence
    __device__ static void merge(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) {
        static_assert(isPowerOf2(size));
        static_assert(size >= 2 * WARP_SIZE);
        constexpr int arr_len = size / WARP_SIZE;  // per-lane element count

        // Compare-and-swap between the lower and upper halves of the per-lane
        // arrays, keeping the "better" element in the upper half.
        constexpr int stride = arr_len / 2;
        for (int i = 0; i < stride; ++i) {
            int const other_i   = i + stride;
            T&        val       = val_arr[i];
            T&        other_val = val_arr[other_i];
            bool      is_better;
            if constexpr (is_stable) {
                is_better = is_better_than<ascending>(val, other_val, idx_arr[i], idx_arr[other_i]);
            } else {
                is_better = is_better_than<ascending>(val, other_val);
            }

            if (is_better) {
                // Swap value and index together so pairs stay consistent.
                T tmp     = val;
                val       = other_val;
                other_val = tmp;

                idxT tmp2        = idx_arr[i];
                idx_arr[i]       = idx_arr[other_i];
                idx_arr[other_i] = tmp2;
            }
        }

        // Each half is now bitonic; merge the halves independently.
        BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(val_arr, idx_arr);
        BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(val_arr + arr_len / 2,
                                                                              idx_arr + arr_len / 2);
    }
};

// Warp-cooperative bitonic sort of `size` elements (size/WARP_SIZE per lane):
// sort the first half one way and the second half the other way to form a
// bitonic sequence, then merge it into `ascending` order.
template<int size, bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort {
    __device__ static void sort(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) {
        static_assert(isPowerOf2(size));
        static_assert(size >= 2 * WARP_SIZE);
        constexpr int arr_len = size / WARP_SIZE;  // per-lane element count

        BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
        BitonicSort<size / 2, false, T, idxT, is_stable>::sort(val_arr + arr_len / 2, idx_arr + arr_len / 2);
        BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(val_arr, idx_arr);
    }
};

// Base case: a 32-element sort with exactly one (value, index) pair per lane;
// lane exchanges are done with __shfl_xor_sync butterflies.
template<bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort<32, ascending, T, idxT, is_stable> {
    __device__ static void sort(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) {
        int const lane = threadIdx.x % WARP_SIZE;

        // ascending doesn't matter before merging since all we need is a bitonic sequence
        for (int stage = 0; stage < 4; ++stage) {
            for (int stride = (1 << stage); stride > 0; stride /= 2) {
                // `reverse` flips direction between adjacent sorted runs at
                // this stage; `is_second` marks the upper lane of a pair.
                bool reverse   = (lane >> stage) & 2;
                bool is_second = lane & stride;

                // Fetch the partner lane's pair.
                T    other     = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
                idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);

                bool is_better;
                if constexpr (is_stable) {
                    // Stable compare: ties on value are broken by index.
                    if constexpr (ascending) {
                        is_better = ((*val_arr > other) || ((*val_arr == other) && (*idx_arr < other_idx)))
                                    != (reverse != is_second);
                    } else {
                        is_better = ((*val_arr > other) || ((*val_arr == other) && (*idx_arr > other_idx)))
                                    != (reverse != is_second);
                    }
                } else {
                    // is_better = (*val_arr != other) && is_better_than(*val_arr, other, (reverse == is_second));
                    is_better = (*val_arr != other && (*val_arr > other) != (reverse != is_second));
                }
                if (is_better) {
                    // Adopt the partner's pair when it belongs on this side.
                    *val_arr = other;
                    *idx_arr = other_idx;
                }
            }
        }

        BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, idx_arr);
    }
};

// Base-case merge: 32 elements, one (value, index) pair per lane, merged with
// __shfl_xor_sync butterflies of stride 16, 8, 4, 2, 1.
template<bool ascending, bool reverse, typename T, typename idxT, bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
    __device__ static void merge(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) {
        int const lane = threadIdx.x % WARP_SIZE;
        for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
            bool  is_second = lane & stride;
            T&    val       = *val_arr;
            T     other     = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
            idxT& idx       = *idx_arr;
            idxT  other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);

            bool is_better;
            if constexpr (is_stable) {
                if constexpr (ascending) {
                    is_better = ((*val_arr > other) || ((*val_arr == other) && (*idx_arr < other_idx)))
                                == (reverse != is_second);  // for min
                } else {
                    is_better = ((*val_arr > other) || ((*val_arr == other) && (*idx_arr > other_idx)))
                                == (reverse != is_second);  // for max
                }
            } else {
                // is_better = (val != other) && (is_better_than(val, other, (reverse != is_second)));
                // NOTE(review): the unstable path keys direction off `ascending`
                // while the stable path above uses `reverse` — confirm this
                // asymmetry is intentional.
                is_better = (val != other && ((val > other) == (ascending != is_second)));
            }

            if (is_better) {
                // Take the partner's pair when it belongs on this side.
                val = other;
                idx = other_idx;
            }
        }
    }
};

// Warp-wide sorted list of the best `capacity` (value, index) pairs: each lane
// owns capacity/WARP_SIZE slots. `dummy` is the sentinel for empty slots (the
// worst possible value for the chosen direction). `greater` selects whether
// larger or smaller values are "better"; `is_stable` breaks ties by index.
template<int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSort {
public:
    __device__ WarpSort(idxT k, T dummy): lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
        static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));

        // Fill all slots with the sentinel so merges start from a valid state.
        for (int i = 0; i < max_arr_len_; ++i) {
            val_arr_[i] = dummy_;
            idx_arr_[i] = 0;
        }
    }

    // load and merge k sorted values
    __device__ void load_sorted(T const* __restrict__ in, idxT const* __restrict__ in_idx, idxT start) {
        // Read the external sorted run back-to-front so it lines up with the
        // register layout, keep the better element per slot, then bitonic-merge.
        idxT idx = start + WARP_SIZE - 1 - lane_;
        for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
            if (idx < start + k_) {
                T    t = in[idx];
                bool is_better;
                if constexpr (is_stable) {
                    is_better = is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
                } else {
                    is_better = is_better_than<greater>(t, val_arr_[i]);
                }
                if (is_better) {
                    val_arr_[i] = t;
                    idx_arr_[i] = in_idx[idx];
                }
            }
        }

        BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(val_arr_, idx_arr_);
    }

    // Write the first k_ sorted values and their indices to out / out_idx.
    __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
        for (int i = 0; i < max_arr_len_; ++i) {
            idxT out_i = i * WARP_SIZE + lane_;
            if (out_i < k_) {
                out[out_i]     = val_arr_[i];
                out_idx[out_i] = idx_arr_[i];
            }
        }
    }

    // Write only the first k_ sorted indices to out_idx.
    __device__ void dumpIdx(idxT* __restrict__ out_idx) const {
        for (int i = 0; i < max_arr_len_; ++i) {
            idxT out_i = i * WARP_SIZE + lane_;
            if (out_i < k_) {
                out_idx[out_i] = idx_arr_[i];
            }
        }
    }

protected:
    static constexpr int max_arr_len_ = capacity / WARP_SIZE;  // slots per lane

    T    val_arr_[max_arr_len_];  // per-lane value slots
    idxT idx_arr_[max_arr_len_];  // matching per-lane index slots

    int const  lane_;   // lane id within the warp
    idxT const k_;      // number of results requested (k_ <= capacity)
    T const    dummy_;  // sentinel value for unfilled slots

};  // end class WarpSort

// Streaming warp-wide top-k selector: feed candidates through add(), call
// done() exactly once, then read results via the inherited dump/dumpIdx.
// Candidates that beat the current k-th best are staged in a per-warp shared
// memory buffer; each time the buffer reaches WARP_SIZE entries it is bitonic-
// sorted and merged into the sorted register queue.
template<int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSelect: public WarpSort<capacity, greater, T, idxT, is_stable> {
public:
    __device__ WarpSelect(idxT k, T dummy):
        WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy), k_th_(dummy), k_th_lane_((k - 1) % WARP_SIZE) {

        extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];

        // Carve this warp's value/index staging slots out of dynamic smem;
        // the index region begins after the 256-byte-aligned value region
        // (layout must match calc_smem_size_for_block_wide).
        int const num_of_warp = blockDim.x / WARP_SIZE;
        int const warp_id     = threadIdx.x / WARP_SIZE;
        val_smem_             = reinterpret_cast<T*>(smem_buf);
        val_smem_ += warp_id * WARP_SIZE;
        idx_smem_ =
            reinterpret_cast<idxT*>(smem_buf + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
        idx_smem_ += warp_id * WARP_SIZE;
    }

    // Warp-cooperatively add the elements in[start, end), lane-strided.
    __device__ void add(T const* in, idxT start, idxT end) {
        idxT const end_for_fullwarp = round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
        for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
            // Lanes past `end` contribute the dummy sentinel so the whole warp
            // keeps executing the ballot/shuffle protocol together.
            T val = (i < end) ? in[i] : dummy_;
            add(val, i);
        }
    }

    // Add one candidate per lane; must be called by all 32 lanes together.
    __device__ void add(T val, idxT idx) {
        // Only candidates beating the cached k-th best are worth staging.
        bool do_add;
        if constexpr (is_stable) {
            do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
        } else {
            do_add = is_better_than<greater>(val, k_th_);
        }

        uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
        if (mask == 0) {
            return;
        }

        // Compact accepted candidates into the staging buffer; position is
        // current buffer length plus the count of accepting lanes below us.
        int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
        if (do_add && pos < WARP_SIZE) {
            val_smem_[pos] = val;
            idx_smem_[pos] = idx;
            do_add         = false;
        }
        smem_buf_len_ += __popc(mask);
        if (smem_buf_len_ >= WARP_SIZE) {
            // Buffer full: fold one WARP_SIZE chunk into the sorted queue.
            __syncwarp();
            merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
            smem_buf_len_ -= WARP_SIZE;
        }
        if (do_add) {
            // This candidate overflowed the old buffer; it starts the new one.
            pos -= WARP_SIZE;
            val_smem_[pos] = val;
            idx_smem_[pos] = idx;
        }
        __syncwarp();
    }

    // Flush any partially filled staging buffer, then block-sync so the
    // shared memory can be reused for cross-warp merging.
    __device__ void done() {
        if (smem_buf_len_) {
            T    val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
            idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
            merge_buf_(val, idx);
        }

        // after done(), smem is used for merging results among warps
        __syncthreads();
    }

private:
    // Broadcast the k-th best value (and index, if stable) from the lane
    // whose last slot holds it.
    __device__ void set_k_th_() {
        k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
        if constexpr (is_stable) {
            k_th_idx_ = __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
        }
    }

    // Sort the 32 staged candidates, merge them into the register queue, and
    // refresh the cached admission threshold.
    __device__ void merge_buf_(T val, idxT idx) {
        BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);

        T& old = val_arr_[max_arr_len_ - 1];

        bool is_better;
        if constexpr (is_stable) {
            is_better = is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
        } else {
            is_better = is_better_than<greater>(val, old);
        }

        if (is_better) {
            old                        = val;
            idx_arr_[max_arr_len_ - 1] = idx;
        }

        BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(val_arr_, idx_arr_);

        set_k_th_();
    }

    using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;

    T*    val_smem_;          // this warp's value staging slots in dynamic smem
    idxT* idx_smem_;          // this warp's index staging slots in dynamic smem
    int   smem_buf_len_ = 0;  // number of valid staged entries

    T         k_th_;       // cached k-th best value (admission threshold)
    idxT      k_th_idx_;   // its index (meaningful only when is_stable)
    int const k_th_lane_;  // lane holding the k-th element of the sorted queue
};  // end class WarpSelect
}  // namespace warp_topk

// Warp-cooperative top-2 reduction: writes (max1 + max2) over
// input[0, num_experts_per_group) to *output (lane 0 performs the store).
template<typename T>
__device__ void topk_with_k2(T*                               output,
                             T const*                         input,
                             cg::thread_block_tile<32> const& tile,
                             int32_t const                    lane_id,
                             int const                        num_experts_per_group) {
    // Get the top2 per thread
    T largest        = -INFINITY;
    T second_largest = -INFINITY;

    if (num_experts_per_group > WARP_SIZE) {
        for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
            T value = input[i];
            if (value > largest) {
                second_largest = largest;
                largest        = value;
            } else if (value > second_largest) {
                second_largest = value;
            }
        }
    } else {
        // Each lane touches at most one element here, so a per-lane top-2 is
        // unnecessary; second_largest stays at -INFINITY.
        for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
            largest = input[i];
        }
    }

    __syncwarp();  // Ensure all threads have valid data before reduction
    // Get the top2 warpwise
    T max1 = cg::reduce(tile, largest, cg::greater<T>());

    T    max2          = max1;
    bool equal_to_max1 = (max1 == largest);

    int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));

    if (count_max1 == 1) {
        // max1 is unique: remove it from its holder and re-reduce for max2.
        // If max1 occurs on several lanes, max2 == max1 already.
        largest = (largest == max1) ? second_largest : largest;
        max2    = cg::reduce(tile, largest, cg::greater<T>());
    }

    if (lane_id == 0) {
        *output = max1 + max2;
    }
}

// First-stage kernel: one warp per (token, group) case. Each warp reduces its
// group's expert scores to top1 + top2 and stores the sum into output[case].
template<typename T>
__global__ void topk_with_k2_kernel(T*            output,
                                    T*            input,
                                    int64_t const num_tokens,
                                    int64_t const num_cases,
                                    int64_t const n_group,
                                    int64_t const num_experts_per_group) {

    int32_t const lane    = threadIdx.x % WARP_SIZE;
    int32_t const warp    = threadIdx.x / WARP_SIZE;
    int32_t const case_id = NUM_WARPS_PER_BLOCK * blockIdx.x + warp;

    if (case_id < num_cases) {
        T const* case_input  = input + case_id * num_experts_per_group;
        T*       case_output = output + case_id;

        cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cg::this_thread_block());

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        asm volatile("griddepcontrol.wait;");
#endif
        topk_with_k2(case_output, case_input, tile, lane, num_experts_per_group);
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}

// Second-stage kernel, one warp per token:
//  1) select the topk_group best expert groups via group_scores,
//  2) run a stable warp top-k over scores_with_bias of the admitted groups,
//  3) gather the chosen experts' original scores, optionally normalize by
//     their sum (norm_node == 1), apply routed_scaling_factor, and write
//     topk_values / topk_indices.
// Dynamic shared memory is sized by warp_topk::calc_smem_size_for_block_wide
// and reused here to stage per-warp top-k indices and values.
template<typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_kernel(T*            scores,
                                              T const*      group_scores,
                                              T*            topk_values,
                                              IdxT*         topk_indices,
                                              T*            scores_with_bias,
                                              int64_t const num_tokens,
                                              int64_t const n_group,
                                              int64_t const topk_group,
                                              int64_t const topk,
                                              int64_t const num_experts,
                                              int64_t const num_experts_per_group,
                                              int           norm_node,
                                              double        routed_scaling_factor) {
    // TODO: norm_node values other than 0/1 are not distinguished here.

    int32_t warp_id = threadIdx.x / WARP_SIZE;
    int32_t lane_id = threadIdx.x % WARP_SIZE;
    int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
    // Pointer arithmetic is done unconditionally; every dereference below is
    // guarded by `case_id < num_tokens`.
    scores_with_bias += case_id * num_experts;
    scores += case_id * num_experts;
    group_scores += case_id * n_group;
    topk_values += case_id * topk;
    topk_indices += case_id * topk;

    int32_t align_num_experts_per_group = warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

    cg::thread_block          block = cg::this_thread_block();
    cg::thread_block_tile<32> tile  = cg::tiled_partition<32>(block);

    extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to store the target topk idx
    int32_t*               s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
    T* s_topk_value                   = reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + warp_id * topk;
    s_topk_idx += warp_id * topk;

    T       value            = cuda::std::numeric_limits<T>::lowest();
    T       topk_group_value = cuda::std::numeric_limits<T>::lowest();
    int32_t num_equalto_topkth_group;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before acqbulk because it's ptr arithmetic
#endif

    if (case_id < num_tokens) {
        // calculate group_idx
        // Find the topk_group-th largest group score by repeatedly removing
        // the current maximum; target_num_min is the eliminated-lane count at
        // which enough groups have been peeled off.
        int32_t target_num_min = WARP_SIZE - n_group + topk_group;
        if (lane_id < n_group
            && (isfinite(
                cuda_cast<float, T>(group_scores[lane_id]))))  // The check is necessary to avoid abnormal input
        {
            value = group_scores[lane_id];
        }

        int count_equal_to_top_value     = WARP_SIZE - n_group;
        int pre_count_equal_to_top_value = 0;
        // Use loop to find the largset top_group
        while (count_equal_to_top_value < target_num_min) {
            __syncwarp();  // Ensure all threads have valid data before reduction
            topk_group_value = cg::reduce(tile, value, cg::greater<T>());
            if (value == topk_group_value) {
                value = cuda::std::numeric_limits<T>::lowest();
            }
            pre_count_equal_to_top_value = count_equal_to_top_value;
            count_equal_to_top_value =
                __popc(__ballot_sync(FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::lowest())));
        }
        // How many groups tied at the threshold score may still be admitted.
        num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
    }
    __syncthreads();

    // Constructed by every thread (the ctor only derives smem pointers from
    // blockDim/threadIdx); only warps with a valid token feed it below.
    warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t, /* is_stable */ true> queue(
        (int32_t)topk, -INFINITY);

    int  count_equalto_topkth_group = 0;
    // topk_group_value stays at lowest() when the selection loop above never
    // produced a threshold (e.g. no valid token or abnormal input).
    bool if_proceed_next_topk       = (topk_group_value != cuda::std::numeric_limits<T>::lowest());
    if (case_id < num_tokens && if_proceed_next_topk) {
        for (int i_group = 0; i_group < n_group; i_group++) {
            // Admit a group if it beats the threshold, or ties it while the
            // tie budget (num_equalto_topkth_group) is not yet exhausted.
            if ((group_scores[i_group] > topk_group_value)
                || ((group_scores[i_group] == topk_group_value)
                    && (count_equalto_topkth_group < num_equalto_topkth_group))) {
                int32_t offset = i_group * num_experts_per_group;
                for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) {
                    // Out-of-range / non-finite entries become lowest() so
                    // they can never enter the top-k.
                    T candidates =
                        (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(scores_with_bias[offset + i])) ?
                            scores_with_bias[offset + i] :
                            cuda::std::numeric_limits<T>::lowest();
                    queue.add(candidates, offset + i);
                }
                if (group_scores[i_group] == topk_group_value) {
                    count_equalto_topkth_group++;
                }
            }
        }
        queue.done();
        __syncwarp();
        // Get the topk_idx
        queue.dumpIdx(s_topk_idx);
        __syncwarp();
    }

    // Load the valid score value
    // Calculate the summation
    float topk_sum = norm_node == 0 ? 1.0 : 1e-20;
    if (case_id < num_tokens && if_proceed_next_topk) {
        for (int i = lane_id; i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk); i += WARP_SIZE) {
            T value = i < topk ? scores[s_topk_idx[i]] : cuda_cast<T, float>(0.0f);  // Load the valid value of expert
            if (i < topk) {
                s_topk_value[i] = value;
            }
            if (norm_node == 1) {
                // cg::reduce (found via ADL) gives every lane the warp-wide
                // sum of this chunk of values.
                topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
            }
        }
    }
    __syncthreads();

    if (case_id < num_tokens) {
        if (if_proceed_next_topk) {
            for (int i = lane_id; i < topk; i += WARP_SIZE) {
                float value     = cuda_cast<float, T>(s_topk_value[i]) / topk_sum * routed_scaling_factor;
                topk_indices[i] = s_topk_idx[i];
                topk_values[i]  = cuda_cast<T, float>(value);
            }
        } else {
            // Fallback: emit experts 0..topk-1 with uniform weights.
            for (int i = lane_id; i < topk; i += WARP_SIZE) {
                topk_indices[i] = i;
                topk_values[i]  = cuda_cast<T, float>(1.0f / topk);
            }
        }
        // Note: when if_proceed_next_topk==false, choose the first 8 experts as the default result.
        //@TODO: check if this default strategy is acceptable. Might need to leave it as nan array.
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template<typename T, typename IdxT>
void invokeNoAuxTc(T*                 scores,
                   T*                 group_scores,
                   T*                 topk_values,
                   IdxT*              topk_indices,
                   T*                 scores_with_bias,
                   int64_t const      num_tokens,
                   int64_t const      num_experts,
                   int64_t const      n_group,
                   int64_t const      topk_group,
                   int64_t const      topk,
                   int                norm_node,
                   double const       routed_scaling_factor,
                   cudaStream_t const stream) {
    // Host launcher for the no-aux top-k routing pipeline.
    //  Stage 1: topk_with_k2_kernel condenses scores_with_bias into one score
    //           per (token, group) pair, written to group_scores (one warp per
    //           case).
    //  Stage 2: group_idx_and_topk_idx_kernel selects topk_group groups per
    //           token, then the final topk experts, writing topk_values and
    //           topk_indices.
    // Both kernels are launched through LAUNCH_KERNEL_WITH_PDL to pair with
    // the griddepcontrol instructions inside them.
    if (num_tokens <= 0) {
        return;  // nothing to do; avoid launching kernels on an empty problem
    }

    int64_t num_cases               = num_tokens * n_group;
    int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
    LAUNCH_KERNEL_WITH_PDL((topk_with_k2_kernel<T>),
                           topk_with_k2_num_blocks,
                           BLOCK_SIZE,
                           0,
                           stream,
                           group_scores,
                           scores_with_bias,
                           num_tokens,
                           num_cases,
                           n_group,
                           num_experts / n_group);
    check_cuda_error();

    int64_t topk_with_k_group_num_blocks = (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
    // Shared memory must cover both the warp-select staging buffers and the
    // per-warp top-k index/value cache (see calc_smem_size_for_block_wide).
    size_t  dynamic_smem_in_bytes = warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK, topk);
    LAUNCH_KERNEL_WITH_PDL((group_idx_and_topk_idx_kernel<T, IdxT>),
                           topk_with_k_group_num_blocks,
                           BLOCK_SIZE,
                           dynamic_smem_in_bytes,
                           stream,
                           scores,
                           group_scores,
                           topk_values,
                           topk_indices,
                           scores_with_bias,
                           num_tokens,
                           n_group,
                           topk_group,
                           topk,
                           num_experts,
                           num_experts / n_group,
                           norm_node,
                           routed_scaling_factor);
    check_cuda_error();
    // (Removed a stale `#undef LAUNCH_KERNEL`: that macro is never defined in
    // this file — all launches use LAUNCH_KERNEL_WITH_PDL.)
}

// Explicit instantiations of the public entry point. Only the float variants
// are currently compiled; the half / bf16 variants remain disabled below.
#define INSTANTIATE_NOAUX_TC(T, IdxT)                                                                                  \
    template void invokeNoAuxTc<T, IdxT>(T * scores,                                                                   \
                                         T * group_scores,                                                             \
                                         T * topk_values,                                                              \
                                         IdxT * topk_indices,                                                          \
                                         T * scores_with_bias,                                                         \
                                         int64_t const      num_tokens,                                                \
                                         int64_t const      num_experts,                                               \
                                         int64_t const      n_group,                                                   \
                                         int64_t const      topk_group,                                                \
                                         int64_t const      topk,                                                      \
                                         int                norm_node,                                                 \
                                         double const       routed_scaling_factor,                                     \
                                         cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float, int32_t);
INSTANTIATE_NOAUX_TC(float, int64_t);
// INSTANTIATE_NOAUX_TC(half, int32_t);
// #ifdef ENABLE_BF16
// INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
// #endif

}  // namespace rtp_llm
